import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from k_means import KMeans

# Load the Iris dataset from a remote CSV (requires network access).
data_set = pd.read_csv('https://www.gairuo.com/file/data/dataset/iris.data')
print(data_set)
# Distinct species names, in order of first appearance; used for the
# "label known" reference plot.
iris_types = data_set['species'].unique()
print(iris_types)

# The two petal features are used both as KMeans inputs and as plot axes.
x1_axis = 'petal_length'
x2_axis = 'petal_width'

# Selecting two columns already produces an (n_samples, 2) array, so no
# explicit reshape is needed.
x_train = data_set[[x1_axis, x2_axis]].to_numpy()

K = 3  # number of clusters requested from KMeans
max_iterations = 50
km = KMeans(x_train, K)
# NOTE(review): closet_idxs is assumed to hold one cluster id in [0, K) per
# sample (possibly shaped (n, 1), hence the flatten() when plotting) — confirm
# against k_means.KMeans.train.
centroids, closet_idxs = km.train(max_iterations)

# Side-by-side comparison: true species labels (left) vs. KMeans clusters (right).
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
for species in iris_types:  # avoid shadowing the builtin `type`
    is_species = data_set['species'] == species
    plt.scatter(
        x=data_set.loc[is_species, x1_axis],
        y=data_set.loc[is_species, x2_axis],
        label=species,
    )
plt.xlabel(x1_axis)
plt.ylabel(x2_axis)
plt.legend()
plt.title('label known')

plt.subplot(1, 2, 2)
# Cluster ids are arbitrary (0..K-1) and do not correspond to species order,
# so iterate over the clusters themselves and label them as clusters. This
# also keeps the plot correct when K differs from the number of species.
cluster_ids = np.asarray(closet_idxs).flatten()
for cluster in range(K):
    in_cluster = cluster_ids == cluster
    plt.scatter(
        x=data_set.loc[in_cluster, x1_axis],
        y=data_set.loc[in_cluster, x2_axis],
        label=f'cluster {cluster}',
    )
plt.xlabel(x1_axis)
plt.ylabel(x2_axis)
plt.legend()
plt.title('kmeans')
plt.show()